import numpy as np
from numpy.testing import assert_equal
from nose.tools import raises
from menpofit.transform import DifferentiableAffine, DifferentiableSimilarity


jac_solution2d = np.array(
    [[[0.,  0.],
    [0.,  0.],
    [0.,  0.],
    [0.,  0.],
    [1.,  0.],
    [0.,  1.]],
    [[0.,  0.],
    [0.,  0.],
    [1.,  0.],
    [0.,  1.],
    [1.,  0.],
    [0.,  1.]],
    [[0.,  0.],
    [0.,  0.],
    [2.,  0.],
    [0.,  2.],
    [1.,  0.],
    [0.,  1.]],
    [[1.,  0.],
    [0.,  1.],
    [0.,  0.],
    [0.,  0.],
    [1.,  0.],
    [0.,  1.]],
    [[1.,  0.],
    [0.,  1.],
    [1.,  0.],
    [0.,  1.],
    [1.,  0.],
    [0.,  1.]],
    [[1.,  0.],
    [0.,  1.],
    [2.,  0.],
    [0.,  2.],
    [1.,  0.],
    [0.,  1.]]])

jac_solution3d = np.array(
    [[[0.,  0.,  0.],
    [0.,  0.,  0.],
    [0.,  0.,  0.],
    [0.,  0.,  0.],
    [0.,  0.,  0.],
    [0.,  0.,  0.],
    [0.,  0.,  0.],
    [0.,  0.,  0.],
    [0.,  0.,  0.],
    [1.,  0.,  0.],
    [0.,  1.,  0.],
    [0.,  0.,  1.]],
    [[0.,  0.,  0.],
    [0.,  0.,  0.],
    [0.,  0.,  0.],
    [0.,  0.,  0.],
    [0.,  0.,  0.],
    [0.,  0.,  0.],
    [1.,  0.,  0.],
    [0.,  1.,  0.],
    [0.,  0.,  1.],
    [1.,  0.,  0.],
    [0.,  1.,  0.],
    [0.,  0.,  1.]],
    [[0.,  0.,  0.],
    [0.,  0.,  0.],
    [0.,  0.,  0.],
    [1.,  0.,  0.],
    [0.,  1.,  0.],
    [0.,  0.,  1.],
    [0.,  0.,  0.],
    [0.,  0.,  0.],
    [0.,  0.,  0.],
    [1.,  0.,  0.],
    [0.,  1.,  0.],
    [0.,  0.,  1.]],
    [[0.,  0.,  0.],
    [0.,  0.,  0.],
    [0.,  0.,  0.],
    [1.,  0.,  0.],
    [0.,  1.,  0.],
    [0.,  0.,  1.],
    [1.,  0.,  0.],
    [0.,  1.,  0.],
    [0.,  0.,  1.],
    [1.,  0.,  0.],
    [0.,  1.,  0.],
    [0.,  0.,  1.]],
    [[0.,  0.,  0.],
    [0.,  0.,  0.],
    [0.,  0.,  0.],
    [2.,  0.,  0.],
    [0.,  2.,  0.],
    [0.,  0.,  2.],
    [0.,  0.,  0.],
    [0.,  0.,  0.],
    [0.,  0.,  0.],
    [1.,  0.,  0.],
    [0.,  1.,  0.],
    [0.,  0.,  1.]],
    [[0.,  0.,  0.],
    [0.,  0.,  0.],
    [0.,  0.,  0.],
    [2.,  0.,  0.],
    [0.,  2.,  0.],
    [0.,  0.,  2.],
    [1.,  0.,  0.],
    [0.,  1.,  0.],
    [0.,  0.,  1.],
    [1.,  0.,  0.],
    [0.,  1.,  0.],
    [0.,  0.,  1.]],
    [[1.,  0.,  0.],
    [0.,  1.,  0.],
    [0.,  0.,  1.],
    [0.,  0.,  0.],
    [0.,  0.,  0.],
    [0.,  0.,  0.],
    [0.,  0.,  0.],
    [0.,  0.,  0.],
    [0.,  0.,  0.],
    [1.,  0.,  0.],
    [0.,  1.,  0.],
    [0.,  0.,  1.]],
    [[1.,  0.,  0.],
    [0.,  1.,  0.],
    [0.,  0.,  1.],
    [0.,  0.,  0.],
    [0.,  0.,  0.],
    [0.,  0.,  0.],
    [1.,  0.,  0.],
    [0.,  1.,  0.],
    [0.,  0.,  1.],
    [1.,  0.,  0.],
    [0.,  1.,  0.],
    [0.,  0.,  1.]],
    [[1.,  0.,  0.],
    [0.,  1.,  0.],
    [0.,  0.,  1.],
    [1.,  0.,  0.],
    [0.,  1.,  0.],
    [0.,  0.,  1.],
    [0.,  0.,  0.],
    [0.,  0.,  0.],
    [0.,  0.,  0.],
    [1.,  0.,  0.],
    [0.,  1.,  0.],
    [0.,  0.,  1.]],
    [[1.,  0.,  0.],
    [0.,  1.,  0.],
    [0.,  0.,  1.],
    [1.,  0.,  0.],
    [0.,  1.,  0.],
    [0.,  0.,  1.],
    [1.,  0.,  0.],
    [0.,  1.,  0.],
    [0.,  0.,  1.],
    [1.,  0.,  0.],
    [0.,  1.,  0.],
    [0.,  0.,  1.]],
    [[1.,  0.,  0.],
    [0.,  1.,  0.],
    [0.,  0.,  1.],
    [2.,  0.,  0.],
    [0.,  2.,  0.],
    [0.,  0.,  2.],
    [0.,  0.,  0.],
    [0.,  0.,  0.],
    [0.,  0.,  0.],
    [1.,  0.,  0.],
    [0.,  1.,  0.],
    [0.,  0.,  1.]],
    [[1.,  0.,  0.],
    [0.,  1.,  0.],
    [0.,  0.,  1.],
    [2.,  0.,  0.],
    [0.,  2.,  0.],
    [0.,  0.,  2.],
    [1.,  0.,  0.],
    [0.,  1.,  0.],
    [0.,  0.,  1.],
    [1.,  0.,  0.],
    [0.,  1.,  0.],
    [0.,  0.,  1.]]])

sim_jac_solution2d = np.array([[[0.,  0.],
                              [0.,  0.],
                              [1.,  0.],
                              [0.,  1.]],
                              [[0.,  1.],
                              [-1.,  0.],
                              [1.,  0.],
                              [0.,  1.]],
                              [[0.,  2.],
                              [-2.,  0.],
                              [1.,  0.],
                              [0.,  1.]],
                              [[1.,  0.],
                              [0.,  1.],
                              [1.,  0.],
                              [0.,  1.]],
                              [[1.,  1.],
                              [-1.,  1.],
                              [1.,  0.],
                              [0.,  1.]],
                              [[1.,  2.],
                              [-2.,  1.],
                              [1.,  0.],
                              [0.,  1.]]])


def test_affine_jacobian_2d_with_positions():
    params = np.array([0, 0.1, 0.2, 0, 30, 70])
    t = DifferentiableAffine.init_identity(2).from_vector(params)
    explicit_pixel_locations = np.array(
        [[0, 0],
        [0, 1],
        [0, 2],
        [1, 0],
        [1, 1],
        [1, 2]])
    dW_dp = t.d_dp(explicit_pixel_locations)
    assert_equal(dW_dp, jac_solution2d)


def test_affine_jacobian_3d_with_positions():
    params = np.ones(12)
    t = DifferentiableAffine.init_identity(3).from_vector(params)
    explicit_pixel_locations = np.array(
        [[0, 0, 0],
        [0, 0, 1],
        [0, 1, 0],
        [0, 1, 1],
        [0, 2, 0],
        [0, 2, 1],
        [1, 0, 0],
        [1, 0, 1],
        [1, 1, 0],
        [1, 1, 1],
        [1, 2, 0],
        [1, 2, 1]])
    dW_dp = t.d_dp(explicit_pixel_locations)
    assert_equal(dW_dp, jac_solution3d)


def test_similarity_jacobian_2d():
    params = np.ones(4)
    t = DifferentiableSimilarity.init_identity(2).from_vector(params)
    explicit_pixel_locations = np.array(
        [[0, 0],
        [0, 1],
        [0, 2],
        [1, 0],
        [1, 1],
        [1, 2]])
    dW_dp = t.d_dp(explicit_pixel_locations)
    assert_equal(dW_dp, sim_jac_solution2d)


@raises(ValueError)
def test_similarity_jacobian_3d_raises_dimensionalityerror():
    t = DifferentiableSimilarity(np.eye(4))
    t.d_dp(np.ones([2, 3]))


@raises(ValueError)
def test_similarity_jacobian_2d_raises_dimensionalityerror():
    params = np.ones(4)
    t = DifferentiableSimilarity.init_identity(2).from_vector(params)
    t.d_dp(np.ones([2, 3]))
